import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from tqdm import tqdm


def attack(model, dl, batch_size=64, steps=10, eps=0.3, seed=1):
    '''
    Seed torch's RNG, pick a device and run the iterative attack over ``dl``.

    Args:
        model: module to attack. The batches are moved to the device its
            parameters live on, so the model itself is never relocated.
        dl: data loader yielding ``(x, y)`` batches; must expose ``dl.dataset``.
        batch_size: unused by this function; kept so existing callers that
            pass it keep working.
        steps: number of attack iterations.
        eps: total perturbation budget, split evenly across ``steps``.
        seed: value passed to ``torch.manual_seed``.

    Returns:
        Tuple ``(attack_x, targets)`` as produced by ``pgd_attack``.
    '''
    torch.manual_seed(seed)

    # Follow the model's own placement: choosing CUDA merely because it is
    # available breaks with a device mismatch when the model is still on CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: fall back to the availability-based choice.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    attack_x, targets = pgd_attack(model, dl, device, steps, eps)
    return attack_x, targets


def pgd_attack(model, dl, device, steps=10, eps=0.3):
    '''
    Report clean and adversarial accuracy over ``dl`` and collect the
    adversarial examples.

    Args:
        model: module under attack.
        dl: data loader yielding ``(x, y)`` batches; ``len(dl.dataset)`` is
            used as the accuracy denominator.
        device: device every batch is moved to before use.
        steps: number of attack iterations per batch.
        eps: total perturbation budget; each iteration uses ``eps / steps``.

    Returns:
        Tuple ``(adv_x, adv_y)``: the adversarial inputs and their labels,
        concatenated over all batches and passed through ``squeeze()``.
    '''
    n_samples = len(dl.dataset)

    # Pass 1: accuracy on the unperturbed inputs.
    correct_s = 0
    with torch.no_grad():
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            correct_s += test_step(model, x, y)
    acc = 100. * correct_s / n_samples
    print(f'Standard Acc : {acc}')

    # Pass 2: craft adversarial examples and score the model on them.
    step_size = eps / steps  # loop-invariant per-iteration step
    adv_xs = []
    adv_ys = []
    correct = 0
    for x, y in tqdm(dl):
        x, y = x.to(device), y.to(device)
        adv_x = pgd(model, x, y, steps=steps, eps=step_size)
        # Scoring needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            correct += test_step(model, adv_x, y)

        adv_xs.append(adv_x)
        adv_ys.append(y)
    test_acc = 100. * correct / n_samples
    print(f'Adv. Robust Acc : {test_acc}')
    return torch.cat(adv_xs).squeeze(), torch.cat(adv_ys).squeeze()


def gen_grad(model, x, y):
    '''
    Return the gradient of the mean cross-entropy loss w.r.t. the input ``x``.

    The model is switched to eval mode. The gradient is taken on a detached
    view of ``x``, so the caller's tensor is left untouched (no
    ``requires_grad`` flip, no ``.grad`` populated) and repeated calls on the
    same tensor cannot accumulate into each other. The model's parameter
    gradients are not written either.

    Args:
        model: module mapping ``x`` to logits.
        x: input batch (must be a floating-point tensor to be differentiable).
        y: targets accepted by ``F.cross_entropy`` for the model's logits.

    Returns:
        Tensor with the same shape as ``x`` holding d(loss)/d(x).
    '''
    model.eval()

    # enable_grad keeps this working even when the caller sits in no_grad().
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = model(x_in)
        loss = F.cross_entropy(logits, y, reduction='mean')
        # Differentiate w.r.t. the input only; nothing is accumulated into
        # any .grad attribute.
        (grad,) = torch.autograd.grad(loss, x_in)
    return grad


def fgsm(x, grad, eps=0.3, clipping=True, clip_min=-1, clip_max=1):
    '''
    Single FGSM step: move ``x`` by ``eps`` along the sign of ``grad``.

    Args:
        x: input tensor to perturb.
        grad: gradient of the loss w.r.t. ``x``; only its sign is used.
        eps: step size applied to every element.
        clipping: when true, clamp the result to ``[clip_min, clip_max]``.
        clip_min: lower bound of the valid input range (default -1).
        clip_max: upper bound of the valid input range (default 1).

    Returns:
        The perturbed tensor, detached from any autograd graph.
    '''
    # Add perturbation to original example to obtain adversarial example
    adv_x = x.detach() + eps * grad.detach().sign()
    if clipping:
        # Keep the example inside the valid data range.
        adv_x = torch.clamp(adv_x, clip_min, clip_max)
    return adv_x


def pgd(model, x, y, steps, eps):
    '''
    Iterative FGSM (I-FGSM): apply ``fgsm`` ``steps`` times, recomputing the
    input gradient before every step.

    Note that ``eps`` here is the size of each individual step, not the total
    budget, and no projection back onto a ball around ``x`` is performed
    beyond the clamping done inside ``fgsm``.

    Args:
        model: module under attack.
        x: clean input batch; it is not modified.
        y: labels for ``x``.
        steps: number of FGSM iterations.
        eps: per-iteration step size.

    Returns:
        The adversarial batch after ``steps`` iterations.
    '''
    # Start from a detached view so the caller's tensor is never flagged as
    # requiring grad and a stale ``x.grad`` cannot leak into the first step.
    adv_x = x.detach()
    # iteratively apply the FGSM with small step size
    for _ in range(steps):
        grad = gen_grad(model, adv_x, y)
        adv_x = fgsm(adv_x, grad, eps)
    return adv_x


def test_step(model, data, labels):
    model.eval()
    logits = model(data)

    # Prediction for the test set
    preds = logits.max(1)[1]
    return preds.eq(labels).sum().item()